{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "262Td2bPGm8L",
        "outputId": "e955a49d-00f9-4ebf-ca7c-b5c2957336f3"
      },
      "outputs": [],
      "source": [
        "# use the %pip magic so the package is installed into the environment of the running kernel\n",
        "%pip install --upgrade --user transformers"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OwAkTJcAJnkG"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from transformers.file_utils import is_tf_available, is_torch_available, is_torch_tpu_available\n",
        "from transformers import BertTokenizerFast, BertForSequenceClassification\n",
        "from transformers import Trainer, TrainingArguments\n",
        "import numpy as np\n",
        "import random\n",
        "from sklearn.datasets import fetch_20newsgroups\n",
        "from sklearn.model_selection import train_test_split"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sFv-FiYtKuuf"
      },
      "outputs": [],
      "source": [
        "def set_seed(seed: int):\n",
        "    \"\"\"\n",
        "    Seed every random number generator this notebook may touch so that runs are repeatable.\n",
        "\n",
        "    ``random`` and ``numpy`` are always seeded; ``torch`` (CPU and all CUDA devices) and ``tf``\n",
        "    are seeded only when the corresponding framework is reported as available.\n",
        "\n",
        "    Args:\n",
        "        seed (:obj:`int`): The value handed to each generator.\n",
        "    \"\"\"\n",
        "    for seed_fn in (random.seed, np.random.seed):\n",
        "        seed_fn(seed)\n",
        "\n",
        "    if is_torch_available():\n",
        "        torch.manual_seed(seed)\n",
        "        # the CUDA call is safe even on machines without a GPU\n",
        "        torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "    if is_tf_available():\n",
        "        import tensorflow as tf\n",
        "\n",
        "        tf.random.set_seed(seed)\n",
        "\n",
        "set_seed(1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gDRivaE1KYWA"
      },
      "outputs": [],
      "source": [
        "# the model we are going to train: base uncased BERT\n",
        "# check text classification models here: https://huggingface.co/models?filter=text-classification\n",
        "model_name = \"bert-base-uncased\"\n",
        "# max sequence length for each document/sentence sample\n",
        "max_length = 512"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HWtlkkVKXHc0"
      },
      "outputs": [],
      "source": [
        "# load the tokenizer\n",
        "tokenizer = BertTokenizerFast.from_pretrained(model_name, do_lower_case=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aSanWgrgXhbG"
      },
      "outputs": [],
      "source": [
        "def read_20newsgroups(test_size=0.2):\n",
        "  \"\"\"Fetch the 20 Newsgroups corpus and split it into training and validation parts.\n",
        "\n",
        "  The whole corpus (subset \"all\") is loaded shuffled, with headers, footers and quotes removed.\n",
        "\n",
        "  Args:\n",
        "    test_size: forwarded to ``train_test_split`` as the size of the held-out part.\n",
        "\n",
        "  Returns:\n",
        "    A pair ``(split, label_names)``: ``split`` is what ``train_test_split`` returns for the\n",
        "    documents and their targets, ``label_names`` is the dataset's ``target_names``.\n",
        "  \"\"\"\n",
        "  newsgroups = fetch_20newsgroups(\n",
        "      subset=\"all\",\n",
        "      shuffle=True,\n",
        "      remove=(\"headers\", \"footers\", \"quotes\"),\n",
        "  )\n",
        "  split = train_test_split(newsgroups.data, newsgroups.target, test_size=test_size)\n",
        "  return split, newsgroups.target_names\n",
        "  \n",
        "# call the function\n",
        "(train_texts, valid_texts, train_labels, valid_labels), target_names = read_20newsgroups()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "E_r0ECJaXO-y"
      },
      "outputs": [],
      "source": [
        "# tokenize the dataset, truncating sequences longer than `max_length`\n",
        "# and padding shorter ones up to the longest sequence in each call\n",
        "train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=max_length)\n",
        "valid_encodings = tokenizer(valid_texts, truncation=True, padding=True, max_length=max_length)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7riiOpJiXdTd"
      },
      "outputs": [],
      "source": [
        "class NewsGroupsDataset(torch.utils.data.Dataset):\n",
        "    \"\"\"Torch dataset that pairs tokenizer encodings with their class labels.\n",
        "\n",
        "    ``encodings`` must expose ``.items()`` yielding ``name -> indexable values``;\n",
        "    ``labels`` must be indexable and sized. Sample ``i`` is built from position ``i`` of each.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, encodings, labels):\n",
        "        self.encodings = encodings\n",
        "        self.labels = labels\n",
        "\n",
        "    def __len__(self):\n",
        "        # one sample per label\n",
        "        return len(self.labels)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        sample = {}\n",
        "        for name, values in self.encodings.items():\n",
        "            sample[name] = torch.tensor(values[idx])\n",
        "        # the label is wrapped in a list, so it becomes a tensor of shape (1,)\n",
        "        sample[\"labels\"] = torch.tensor([self.labels[idx]])\n",
        "        return sample\n",
        "\n",
        "# convert our tokenized data into a torch Dataset\n",
        "train_dataset = NewsGroupsDataset(train_encodings, train_labels)\n",
        "valid_dataset = NewsGroupsDataset(valid_encodings, valid_labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "I4aAwDGZXnyk",
        "outputId": "ea1f6f93-940f-43b0-bb9d-317662cb75a7"
      },
      "outputs": [],
      "source": [
        "# load the model and move it to the GPU when one is available, otherwise keep it on the CPU\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "model = BertForSequenceClassification.from_pretrained(model_name, num_labels=len(target_names)).to(device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AkZN1X-zOoe0"
      },
      "outputs": [],
      "source": [
        "from sklearn.metrics import accuracy_score\n",
        "\n",
        "def compute_metrics(pred):\n",
        "  \"\"\"Return the accuracy of a prediction object as ``{'accuracy': value}``.\n",
        "\n",
        "  ``pred.predictions`` is reduced with an argmax over its last axis and the result is\n",
        "  compared against ``pred.label_ids`` using sklearn's ``accuracy_score``.\n",
        "  \"\"\"\n",
        "  predicted_classes = pred.predictions.argmax(-1)\n",
        "  return {'accuracy': accuracy_score(pred.label_ids, predicted_classes)}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "MjMB9f01Pyiu",
        "outputId": "7654f4a7-b323-49ef-e5b8-eca3d8c00369"
      },
      "outputs": [],
      "source": [
        "# Hyper-parameters and bookkeeping options for the training run.\n",
        "# NOTE(review): newer transformers releases rename `evaluation_strategy` to `eval_strategy`;\n",
        "# confirm which spelling the installed version accepts.\n",
        "training_args = TrainingArguments(\n",
        "    output_dir='./results',          # output directory\n",
        "    num_train_epochs=3,              # total number of training epochs\n",
        "    per_device_train_batch_size=8,  # batch size per device during training\n",
        "    per_device_eval_batch_size=20,   # batch size per device during evaluation\n",
        "    warmup_steps=500,                # number of warmup steps for the learning rate scheduler\n",
        "    weight_decay=0.01,               # strength of weight decay\n",
        "    logging_dir='./logs',            # directory for storing logs\n",
        "    load_best_model_at_end=True,     # load the best model when finished training (default metric is loss)\n",
        "    # but you can specify the `metric_for_best_model` argument to switch to accuracy or another metric\n",
        "    logging_steps=400,               # log every 400 steps\n",
        "    save_steps=400,                  # save a checkpoint every 400 steps (same interval as logging)\n",
        "    evaluation_strategy=\"steps\",     # evaluate on a step schedule; `eval_steps` is not set, so the interval\n",
        "    # presumably falls back to `logging_steps` -- confirm against the TrainingArguments docs\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BukYQXs2P35S"
      },
      "outputs": [],
      "source": [
        "trainer = Trainer(\n",
        "    model=model,                         # the instantiated Transformers model to be trained\n",
        "    args=training_args,                  # training arguments, defined above\n",
        "    train_dataset=train_dataset,         # training dataset\n",
        "    eval_dataset=valid_dataset,          # evaluation dataset\n",
        "    compute_metrics=compute_metrics,     # the callback that computes metrics of interest\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 357
        },
        "id": "s5a7QY_wP5iD",
        "outputId": "d8c1b2f2-5c07-414f-f2b7-45b9b406faf2"
      },
      "outputs": [],
      "source": [
        "# train the model\n",
        "trainer.train()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 87
        },
        "id": "cdWlnZxAR0XA",
        "outputId": "4e40738e-5500-4567-f107-e5c019506cf7"
      },
      "outputs": [],
      "source": [
        "# evaluate the current model after training\n",
        "trainer.evaluate()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ojgBg7cfSp4J",
        "outputId": "234d385e-fa7f-483d-fb40-e37ce8652c3a"
      },
      "outputs": [],
      "source": [
        "# saving the fine tuned model & tokenizer\n",
        "model_path = \"20newsgroups-bert-base-uncased\"\n",
        "model.save_pretrained(model_path)\n",
        "tokenizer.save_pretrained(model_path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KXHXb7RYTARC"
      },
      "outputs": [],
      "source": [
        "# reload the fine-tuned model & tokenizer from disk; use the GPU only when one is available\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "model = BertForSequenceClassification.from_pretrained(model_path, num_labels=len(target_names)).to(device)\n",
        "tokenizer = BertTokenizerFast.from_pretrained(model_path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "96uWKthsR8fS"
      },
      "outputs": [],
      "source": [
        "def get_prediction(text):\n",
        "    \"\"\"Classify ``text`` and return the name of its most probable newsgroup.\n",
        "\n",
        "    Relies on the notebook-level ``tokenizer``, ``model``, ``max_length`` and ``target_names``.\n",
        "\n",
        "    Args:\n",
        "        text: the document to classify (a single string).\n",
        "\n",
        "    Returns:\n",
        "        The entry of ``target_names`` whose class received the highest probability.\n",
        "    \"\"\"\n",
        "    # tokenize and place the tensors on whatever device the model lives on\n",
        "    # (instead of assuming a GPU is present)\n",
        "    inputs = tokenizer(text, padding=True, truncation=True, max_length=max_length, return_tensors=\"pt\").to(model.device)\n",
        "    # inference only: no autograd graph is needed, which saves memory and time\n",
        "    with torch.no_grad():\n",
        "        outputs = model(**inputs)\n",
        "    # the first output holds the logits; softmax over the class axis gives probabilities\n",
        "    probs = outputs[0].softmax(1)\n",
        "    # index of the most probable class, converted to a plain int for list indexing\n",
        "    return target_names[probs.argmax().item()]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EUaasmhpSYcg",
        "outputId": "8bb7b6ce-5ae8-4d5c-d8e5-43d42f5a136f"
      },
      "outputs": [],
      "source": [
        "# Example #1\n",
        "text = \"\"\"With the pace of smartphone evolution moving so fast, there's always something waiting in the wings. \n",
        "No sooner have you spied the latest handset, that there's anticipation for the next big thing. \n",
        "Here we look at those phones that haven't yet launched, the upcoming phones for 2021. \n",
        "We'll be updating this list on a regular basis, with those device rumours we think are credible and exciting.\"\"\"\n",
        "print(get_prediction(text))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ht6zywR_rOzA"
      },
      "outputs": [],
      "source": [
        "# Example #2\n",
        "text = \"\"\"\n",
        "A black hole is a place in space where gravity pulls so much that even light can not get out. \n",
        "The gravity is so strong because matter has been squeezed into a tiny space. This can happen when a star is dying.\n",
        "Because no light can get out, people can't see black holes. \n",
        "They are invisible. Space telescopes with special tools can help find black holes. \n",
        "The special tools can see how stars that are very close to black holes act differently than other stars.\n",
        "\"\"\"\n",
        "print(get_prediction(text))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Example #3\n",
        "text = \"\"\"\n",
        "Coronavirus disease (COVID-19) is an infectious disease caused by a newly discovered coronavirus.\n",
        "Most people infected with the COVID-19 virus will experience mild to moderate respiratory illness and recover without requiring special treatment.  \n",
        "Older people, and those with underlying medical problems like cardiovascular disease, diabetes, chronic respiratory disease, and cancer are more likely to develop serious illness.\n",
        "\"\"\"\n",
        "print(get_prediction(text))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "machine_shape": "hm",
      "name": "Untitled21.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3.8.7 64-bit",
      "metadata": {
        "interpreter": {
          "hash": "777490da48e046e3b512f0b24bf037db286a787493a11bf82a9e0f2cbf21bb67"
        }
      },
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
